import copy
import os
import random
import shutil
import subprocess

import numpy as np

from calculate_accept_ratio import compliklihood, accept_or_reject, boundary_handling
from tough_script import load_uge_mesh, uge2tough2, run_tough, pro_mesh, main_script, write_string
from scipy.interpolate import interp1d
from seis2frac_funs import genfrasize, singfracgen, loadDataSet, get_rid_init_f, distances, writeDataSet
from seis2frac_funs import truncate_fracture, write_string
from time import time


def swap_ms(a, b):
    """Return ``a`` and ``b`` as a pair ordered so the smaller value comes first.

    The values are exchanged only when ``a > b``; otherwise they are returned
    in the order they were given.
    """
    return (b, a) if a > b else (a, b)


# Working directory; paths handed to the external tools later on are built from it.
cac = os.getcwd()
f_old, cen_old, dip, angle, input_all = [], [], [], [], []
if not os.path.isfile('inversion_param.txt'):
    print('critical file lost, exiting')
    exit(1)
# Every parameter line has the form "<label>:<value>"; only the text after the
# first colon is kept, in file order.
with open('inversion_param.txt') as op:
    for st in op:
        fields = st.split(':')
        if len(fields) < 2:
            # Blank or colon-less line (e.g. a trailing empty line): there is no
            # value to read and fields[1] would raise IndexError.
            continue
        gb = fields[1].strip()
        input_all.append(gb)
# The script indexes input_all[0] plus 5 + 3 + 6 positional values below and
# further down, so anything shorter cannot be a complete parameter file.
if len(input_all) < 15:
    print('inversion_param.txt is incomplete, exiting')
    exit(1)
num_iter = int(input_all[0])  # total number of iterations
# Positional layout: 5 seismic/fracture entries, 3 forward-model entries, and the
# remaining entries for the inversion method.
seis_param, forward_param, inv_param = input_all[1:6], input_all[6:9], input_all[9:]
real_name = f'{seis_param[2]}.txt'
if not os.path.isfile(real_name):
    print('no valid micro-seismic file found, exiting')
    exit(1)
data = loadDataSet(real_name, ',')
# The first three columns of every record are used as x, y and z coordinates.
seis_x, seis_y, seis_z = [s[0] for s in data], [s[1] for s in data], [s[2] for s in data]
print('read fracture characterization parameters')
try:
    # seis_param[0] is "l0,iters"; from seis_param[1] only the second and third
    # comma-separated fields (r2 and pts) are read here.
    l0, iters = float(seis_param[0].split(',')[0]), int(seis_param[0].split(',')[1])
    r2, pts = float(seis_param[1].split(',')[1]), int(seis_param[1].split(',')[2])
except ValueError:
    print('input value type do not match')
    exit(1)
# Lower and upper corner of the axis-aligned box spanned by the seismic events.
sinfe, ssupb = [min(seis_x), min(seis_y), min(seis_z)], [max(seis_x), max(seis_y), max(seis_z)]
print('fracture generation range:')
print(f'x={sinfe[0]}~{ssupb[0]}, length={ssupb[0] - sinfe[0]}')
print(f'y={sinfe[1]}~{ssupb[1]}, length={ssupb[1] - sinfe[1]}')
print(f'z={sinfe[2]}~{ssupb[2]}, length={ssupb[2] - sinfe[2]}')
# seis_param[3] holds ';'-separated orientation sets, each written as
# "<a>,<b>|<c>,<d>": one pair of bounds before the '|' and one pair after it.
cur_oc = seis_param[3].split(';')
for u in cur_oc:  # prior orientation information
    ut = u.split('|')
    ucd, ucc = ut[0].split(','), ut[1].split(',')
    # The bounds may be given in either order (they are sorted below), so the
    # 0..90 check must look at the smaller and the larger value, not at their
    # positions; otherwise a reversed pair such as "95,10" slips through.
    ucc_min, ucc_max = swap_ms(float(ucc[0]), float(ucc[1]))
    if ucc_min < 0.0 or ucc_max > 90.0:
        print('dip angle out of range, exiting')
        exit(1)
    ucd, ucc = list(map(int, ucd)), list(map(int, ucc))
    ucd[0], ucd[1] = swap_ms(ucd[0], ucd[1])  # automatically recognize minimum and maximum values
    ucc[0], ucc[1] = swap_ms(ucc[0], ucc[1])
    dip.append(ucd)
    angle.append(ucc)
# Defaults for the per-iteration fitting state; the main loop overwrites them.
data_fit, init_unify, d_ini = copy.deepcopy(data), [], 0
try:
    # seis_param[4] is "<au_well>,<minf>,<maxf>": a flag followed by the minimum
    # and maximum number of fractures.
    s5 = list(map(int, seis_param[4].split(',')))
    au_well, minf, maxf = s5[0], s5[1], s5[2]
except ValueError:
    print('number of fractures should be integers')
    exit(1)
if au_well != 0:
    # A non-zero flag creates/truncates the marker file 'well_allow.txt'.
    write_string('well_allow.txt', 'w')
print('read forward flow simulation and inversion method parameters')
# forward_param[0]: comma-separated values later converted to floats and passed
# to pro_mesh; forward_param[1]: comma-separated time values.
PTs = forward_param[0].strip().split(',')
time_splits = forward_param[1].strip().split(',')
# Maximum step per time value is 1/500 of it; the first one is then enlarged
# by a factor of 50.
max_dt = []
for mdt in time_splits:
    max_dt.append(str(float(mdt) / 500))
max_dt[0] = str(float(max_dt[0]) * 50)
tsg = dict(total=time_splits, max_step=max_dt)
# forward_param[2]: lower and upper bound, used later as exponents of 10.
rbrb = forward_param[2].strip().split(',')
infb = float(rbrb[0])
supb = float(rbrb[1])
print('read inversion method parameters')
u_trans = inv_param[0].strip().split(',')
mol = float(u_trans[0])  # molecular mass of tracer
trans = float(u_trans[1])  # unit conversion
p_f_move = inv_param[1].strip().split(',')
for ag in p_f_move:
    prob = float(ag)
    # Every probability has to lie strictly inside (0, 1).
    if prob <= 0.0 or prob >= 1.0:
        print('error: probability value must vary between 0 and 1')
        exit(1)
acc_m = float(p_f_move[0])
rej_m = float(p_f_move[1])
try:
    # inj_c is kept as text (it is reused verbatim in inj_con below); it is only
    # converted here to validate it.
    inj_c = inv_param[2]
    r_e = float(inv_param[3])
    inj_value = float(inj_c)
except ValueError:
    print('error: input must be a float number')
    exit(1)
if inj_value < 0.0:
    print('error: concentration should be greater than 0')
    exit(1)
if r_e <= 0.0:
    print('error: sigma error in observed should be greater than 0')
    exit(1)
units = inv_param[4].strip().split(',')
tems = inv_param[5].strip().split(',')
inj_con = [inj_c, '1.000d-10']
ini_tem = float(tems[0])
inj_tem = float(tems[1])
so_si = []
# 'source_sink.txt' is optional. add_ini must exist either way because it is
# measured right below and passed to get_rid_init_f on every iteration, so it
# defaults to an empty list when the file is absent.
add_ini = []
if os.path.isfile('source_sink.txt'):
    add_ini = loadDataSet('source_sink.txt', ',')  # all elements are string format
# The number of generated fractures must exceed the number of rows read above.
if minf <= len(add_ini):
    minf = len(add_ini) + 1
if minf > maxf:
    # random.randint raises ValueError on an empty range; stop with a clear message.
    print('error: minimum number of fractures exceeds the maximum')
    exit(1)
# Initial state: fracture count for the first run, iteration counter, cumulated
# time, previously accepted state, successful-run count, runs done in this session.
num, iter2, cum_time, p_old, success, icb = random.randint(minf, maxf), 1, 0.0, None, 0, 0
# Run file consumed by the external mesh-generation call made in the main loop.
dfn_gen_inp = f'dfnGen {cac}/gen_4_user_rects.dat'
write_string('4_user_ell_run_file.txt', 'w', dfn_gen_inp)
# Read both templates and close them right away instead of leaving the handles
# open for the rest of the (long) run.
with open('gen_4_user_rects_mother.dat') as g11, open('get_dir.txt') as g12:
    # The mother file is cut at the 'satsta' marker; the directory snippet from
    # 'get_dir.txt' is inserted between the two parts below.
    mot, ad = g11.read().split('satsta'), g12.read()
# Placeholders 'tttttttttt' / 'mmmmmmmmmm' in the snippet stand for the absolute
# paths of the two fracture definition files.
strr, dirs = ['t' * 10, 'm' * 10], ['define_4_user_ellipses_3.dat', 'define_4_user_rects.dat']
for ori, rep in zip(strr, dirs):
    ad = ad.replace(ori, f'{cac}/{rep}')
write_string('gen_4_user_rects.dat', 'w', mot[0] + ad + mot[1])
# Output folders; creating them is a no-op when they already exist.
dirs = ['accept', 'reject']
for s in dirs:
    os.makedirs(s, exist_ok=True)
# Resume from a previous run when both state files are present.
if os.path.isfile('current_fit.txt') and os.path.isfile('next.txt'):  # read restart information
    with open('current_fit.txt') as rs:
        # line 1: distance function and fracture aperture
        d_b = [float(v) for v in rs.readline().strip().split(',')]
        d_old = d_b[0]
        k_old = [d_b[1]]
        # line 2: fracture geometry parameters, one ';'-separated group per fracture
        frac_info = rs.readline().strip().split(';')
        # line 3: one ';'-separated group of centre values per fracture
        cen_str = rs.readline().strip().split(';')
        # line 4: one comma-separated value list, stored as eqr_old
        eqr_str = rs.readline().strip().split(',')
        eqr_old = [float(v) for v in eqr_str]
        for idx, frac_txt in enumerate(frac_info):
            cen_old.append([float(v) for v in cen_str[idx].strip().split(',')])
            f_old.append([float(v) for v in frac_txt.split(',')])
        # line 5: RMSE and log-likelihood
        bayes = [float(v) for v in rs.readline().strip().split(',')]
        r_old = bayes[0]
        log_old = bayes[1]
    # Same field order as the state assembled for every new simulation.
    p_old = [log_old, r_old, d_old, f_old, k_old, cen_old, eqr_old]
    with open('next.txt') as g:
        # Third line: "<label>:<k=v>\t<k=v>\t..." bookkeeping of the last run.
        next_params = g.readlines()[2].strip().split(':')[1].split('\t')


    def _next_value(position):
        """Return the text after '=' in the given tab-separated field."""
        return next_params[position].split('=')[1]


    iter2 = int(_next_value(0)) + 1  # present iteration
    num = int(_next_value(1))
    success = int(_next_value(3))
    cum_time = float(_next_value(4))  # cumulated time
observe1 = loadDataSet('obs_c.txt', ',')  # observation data
t_obs = [row[0] for row in observe1]
c_obs_con = [row[1] for row in observe1]
# Error level handed to the likelihood: r_e times the largest observation.
e_con = r_e * max(c_obs_con)
t_start = time()
print(f'performing inversions begin at {iter2}, total of {num_iter}')
# Main inversion loop. Each pass generates a candidate fracture set, asks the
# external mesher for a mesh, runs the TOUGH2 forward model and hands the result
# to accept_or_reject. A pass can fail in three places (too few fractures, no
# mesh file, failed simulation); a failure during iteration 1 is retried with a
# new fracture count without advancing the counter, later failures are logged in
# 'next.txt' and counted as an iteration.
while iter2 <= num_iter:
    st1 = time()
    if os.path.isfile('next.txt'):
        shutil.copy('next.txt', 'next_bak.txt')  # backup of data for the next run
        write_string('next.txt', 'w')
    print(f'iter {iter2}: fracture generation start with {num} fractures')
    data_fit, init_unify, d_ini, ini_frac, so_si = get_rid_init_f(add_ini, data, l0, dip,
                                                                  angle)  # unfitted points and initial fractures
    geo, para, dist = singfracgen(num, data, data_fit, l0, dip, angle, ini_frac, init_unify, d_ini, iters)
    cens = genfrasize(l0, data, para, r2, pts, sinfe, ssupb)
    if len(cens) < num:
        print(f'iter {iter2}: blank fracture is read, exit forward model')
        if iter2 == 1:
            num = random.randint(minf, maxf)
            continue
        write_string('next.txt', 'a', f'blank fracture is read\nblank record\n')
        num = boundary_handling(num, minf, maxf, acc_m, rej_m, len(p_old[3]), iter2, False)
        iter2 += 1
        icb += 1
        get_time = time() - st1
        cum_time += get_time
        write_string('next.txt', 'a', f'success={success}\ttime={cum_time}')
        continue
    for au in so_si:  # auxillary fractures for sink/sources, not included in mesh generation
        # Entries from index 6 on are used as the plane normal and the first three
        # as a point on it; D closes the plane equation n . x + D = 0.
        nor_ss, tran_ss = au[6:], au[:3]
        D = - np.dot(nor_ss, tran_ss)
        nor_ss.append(D)
        para.append(nor_ss)
    for gtt in para:
        writeDataSet('prior_fracs.txt', gtt, ',')
    rads = truncate_fracture('define_4_user_ellipses_2.dat', sinfe, ssupb)
    dfn_work = f'{cac}/f'
    # NOTE: the command goes through the shell as one string, so a working
    # directory containing spaces would split the -name argument.
    subprocess.call(f'python seis2frac_funs.py -name {dfn_work} -input 4_user_ell_run_file.txt -ncpu 1', shell=True)
    os.remove('prior_fracs.txt')
    # The presence of full_mesh.uge is what marks the external run as successful.
    if os.path.isfile(f'{dfn_work}/full_mesh.uge'):
        shutil.copy('define_4_user_ellipses_2.dat', f'{cac}/f/define_4_user_ellipses_2.dat')
        dfn_mesh = load_uge_mesh(f'{dfn_work}/full_mesh.uge')
        # Row 0 carries the cell count in its second field; the cell rows follow,
        # then one row is skipped before the connection rows start.
        cells, connections = dfn_mesh[1:(int(dfn_mesh[0][1]) + 1)], dfn_mesh[(int(dfn_mesh[0][1]) + 2):]
        # b is drawn log-uniformly between 10**infb and 10**supb; the first
        # permeability entry is b**2 / 12.
        dictt, b = f'{cac}/treact', 10 ** random.uniform(infb, supb)
        perm = [f'{b ** 2 / 12:.2e}', '1.00E-22']
        # Pass only the generated fractures, i.e. para without the auxiliary
        # entries appended above. An explicit end index is required because
        # para[:-0] would be an empty list when so_si is empty.
        mesh, connection, inj, ext = uge2tough2(cells, connections, so_si, b, para[:len(para) - len(so_si)])
        w_c, w_s = pro_mesh(mesh, connection, float(PTs[1]), float(PTs[2]), float(PTs[3]), float(PTs[4]), ref_d=float(PTs[0]))
        write_string(f'{dictt}/MESH', 'w', w_c)
        write_string(f'{dictt}/INCON', 'w', f'{w_s}\n')
        # Same injection rate for every time segment.
        inj_rates = [inj[2] for c in range(len(tsg['total']))]
        fl, ch = main_script(dictt, inj, ext, perm, tsg, inj_rates, inj[3], inj_con)
        t, c_sin_con, TT, P, T = run_tough(dictt, time_splits, fl, ch, units[0], units[1], t_obs, iso=True)
        if t:  # if simulation succeeds
            write_string('next.txt', 'a', f'\tTOUGH2 simulation converged\n')
            g = interp1d(t, c_sin_con, kind='cubic')
            c_con = list(g(t_obs))  # interpolate concentrations at given time points
            rmse_con, log_con = compliklihood(c_con, c_obs_con, mol, trans, e_con)  # compute likelihood and RMSE
            rmse_con /= max(c_obs_con)  # RMSE relative to the largest observation
            # Same field order as the restart state p_old.
            p_new = [log_con, rmse_con, dist, geo, [b], cens, rads]
            if not os.path.isfile('accept.txt'):  # store accepted simulations
                write_string('accept.txt', 'a', 'iter\tRMSE\tloglihood\tb\tdist\tnumf\n')
            if not os.path.isfile('geometry_log.txt'):  # store geometry
                write_string('geometry_log.txt', 'a', 'iters,dist,aperture,occurrence,center,equal_radius\n')
            if not os.path.isfile('inverse_log.txt'):  # store successful forward runs
                write_string('inverse_log.txt', 'a', '# iters,likelihood,q,rmse,numf,accept\n')
            # Returns the state to carry forward and the fracture count for the
            # next iteration.
            p_new, num = accept_or_reject(p_new, minf, maxf, c_con, acc_m, rej_m, p_old, iter2)
            p_old = copy.deepcopy(p_new)
            success += 1
        else:  # TOUGH2 simulation failed, regard as rejected simulation
            print(f'iter {iter2}: TOUGH2 simulation failed, exit forward model')
            if iter2 == 1:
                num = random.randint(minf, maxf)
                continue
            write_string('next.txt', 'a', f'TOUGH2 simulation diverged\n')
            num = boundary_handling(num, minf, maxf, acc_m, rej_m, len(p_old[3]), iter2, False)
    else:
        print(f'iter {iter2}: fracture rejection occur, exit forward model')
        if iter2 == 1:
            num = random.randint(minf, maxf)
            continue
        write_string('next.txt', 'a', f'FRAM fracture rejection activated\nblank record\n')
        num = boundary_handling(num, minf, maxf, acc_m, rej_m, len(p_old[3]), iter2, False)
    iter2 += 1
    icb += 1
    get_time = time() - st1
    cum_time += get_time
    print(f'iter {iter2 - 1} takes {get_time} seconds')
    write_string('next.txt', 'a', f'success={success}\ttime={cum_time}')  # write statistic data for inversions
# num_iter can be 0 (nothing requested); report 0% instead of dividing by zero.
success_rate = success / num_iter * 100 if num_iter else 0.0
print(f'successful rate: {success}/{num_iter}={success_rate}%')
print(f'total time run={(time() - t_start) / 60} minutes for {icb} inversions')
